{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZdRRNrRu8obc"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "H2hlKa7K8rGt"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SDEExiAk4fLb"
      },
      "source": [
        "# Fine-tune Gemma models using LORA for Cake Boss Example\n",
        "\n",
        "Adding additional changes based on feedback\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Demos/business-email-assistant/model-tuning/notebook/bakery_inquiry_model_tuned_with_gemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w1q6-W_mKIT-"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0_EdOg9DPK6Q"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1eeBtYqJsZPG"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m572.2/572.2 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
        "!pip install -q -U keras-nlp\n",
        "!pip install -q -U \"keras>=3\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rGLS-l5TxIR4"
      },
      "source": [
        "### Select a backend\n",
        "\n",
        "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.\n",
        "\n",
        "For this tutorial, configure the backend for JAX."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yn5uy8X8sdD0"
      },
      "outputs": [],
      "source": [
        "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # Or \"torch\" or \"tensorflow\".\n",
        "# Avoid memory fragmentation on JAX backend.\n",
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hZs8XXqUKRmi"
      },
      "source": [
        "### Import packages\n",
        "\n",
        "Import Keras and KerasNLP."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FYHyPUA9hKTf"
      },
      "outputs": [],
      "source": [
        "import keras\n",
        "import keras_nlp"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9T7xe_jzslv4"
      },
      "source": [
        "## Load Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7RCE3fdGhDE5"
      },
      "source": [
        "## Load Model\n",
        "\n",
        "KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.\n",
        "\n",
        "Create the model using the `from_preset` method:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vz5zLEyLstfn"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma_causal_lm_preprocessor\"</span>\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Tokenizer (type)                                   </span>┃<span style=\"font-weight: bold\">                                             Vocab # </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ gemma_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaTokenizer</span>)                   │                                             <span style=\"color: #00af00; text-decoration-color: #00af00\">256,000</span> │\n",
              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
              "</pre>\n"
            ],
            "text/plain": [
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type)                                  \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m                                            Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m)                   │                                             \u001b[38;5;34m256,000\u001b[0m │\n",
              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gemma_causal_lm\"</span>\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Layer (type)                  </span>┃<span style=\"font-weight: bold\"> Output Shape              </span>┃<span style=\"font-weight: bold\">         Param # </span>┃<span style=\"font-weight: bold\"> Connected to               </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)     │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ gemma_backbone                │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2304</span>)        │   <span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],        │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaBackbone</span>)               │                           │                 │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]            │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_embedding               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256000</span>)      │     <span style=\"color: #00af00; text-decoration-color: #00af00\">589,824,000</span> │ gemma_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]       │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>)         │                           │                 │                            │\n",
              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
              "</pre>\n"
            ],
            "text/plain": [
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                 \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape             \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m        Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to              \u001b[0m\u001b[1m \u001b[0m┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m)     │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m)        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ gemma_backbone                │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m)        │   \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m],        │\n",
              "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m)               │                           │                 │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]            │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_embedding               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m)      │     \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n",
              "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m)         │                           │                 │                            │\n",
              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_instruct_2b_en\")\n",
        "gemma_lm.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PVLXadptyo34"
      },
      "source": [
        "### Cake prompt\n",
        "This is from the untuned model. The results aren't exactly what we'd like\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZwQz3xxxKciD"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "From the following get the type of inquiry, (order or request for information), filling, flavor, size, and pickup location and put it into a json\n",
            "Hi,\n",
            "I'd like to order a red velvet cake with custard filling. Please make it 8 inch round\n",
            "and pick it up from the bakery on 22nd street.\n",
            "\n",
            "Thanks!\n",
            " \n",
            "```json\n",
            "{\n",
            "  \"inquiry_type\": \"order\",\n",
            "  \"filling\": \"custard\",\n",
            "  \"flavor\": \"red velvet\",\n",
            "  \"size\": \"8 inch round\",\n",
            "  \"pickup_location\": \"22nd street bakery\"\n",
            "}\n",
            "```\n",
            "```json\n",
            "{\n",
            "  \"inquiry_type\": \"request\",\n",
            "  \"filling\": \"custard\",\n",
            "  \"flavor\": \"red velvet\",\n",
            "  \"size\": \"8 inch round\",\n",
            "  \"pickup_location\": \"22nd street bakery\"\n",
            "}\n",
            "```\n",
            "```json\n",
            "{\n",
            "  \"inquiry_type\": \"order\",\n",
            "  \"filling\": \"custard\",\n",
            "  \"flavor\": \"red velvet\",\n",
            "  \"size\": \"8 inch round\",\n",
            "  \"pickup_location\": \"22nd street bakery\"\n"
          ]
        }
      ],
      "source": [
        "template = \"{instruction}\\n{response}\"\n",
        "\n",
        "prompt = template.format(\n",
        "    instruction=\"\"\"From the following get the type of inquiry, (order or request for information), filling, flavor, size, and pickup location and put it into a json\n",
        "Hi,\n",
        "I'd like to order a red velvet cake with custard filling. Please make it 8 inch round\"\"\",\n",
        "    response=\"\",\n",
        ")\n",
        "# sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)\n",
        "# For our use case greedy is best\n",
        "# gemma_lm.compile(sampler=sampler)\n",
        "gemma_lm.compile(sampler=\"greedy\")\n",
        "\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
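    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The untuned model keeps emitting JSON blocks until generation stops, so downstream code would need to pick out a single well-formed object. A minimal sketch of one way to do that with the standard library; `first_json_block` and the sample string are hypothetical, written for illustration, and not part of KerasNLP:\n",
        "\n",
        "```python\n",
        "import json\n",
        "import re\n",
        "\n",
        "FENCE = \"`\" * 3  # Markdown code fence, built this way to avoid nesting fences here.\n",
        "\n",
        "\n",
        "def first_json_block(text):\n",
        "    \"\"\"Return the first parseable fenced JSON object in `text`, or None.\"\"\"\n",
        "    pattern = re.escape(FENCE) + r\"json\\s*(.*?)\" + re.escape(FENCE)\n",
        "    for match in re.finditer(pattern, text, flags=re.DOTALL):\n",
        "        try:\n",
        "            return json.loads(match.group(1))\n",
        "        except json.JSONDecodeError:\n",
        "            continue  # Skip malformed blocks.\n",
        "    return None\n",
        "\n",
        "\n",
        "# Hypothetical generation: one complete block followed by a truncated repeat.\n",
        "sample = (\n",
        "    \"Some prompt text\\n\"\n",
        "    + FENCE + 'json\\n{\"inquiry_type\": \"order\", \"size\": \"8 inch round\"}\\n' + FENCE + \"\\n\"\n",
        "    + FENCE + 'json\\n{\"inquiry_type\": \"request\",'\n",
        ")\n",
        "parsed = first_json_block(sample)\n",
        "print(parsed)\n",
        "```\n",
        "\n",
        "Taking the first parseable block is only one possible policy; it happens to match the output above, where the first block is the correct one."
      ]
    },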
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fsut8YS9tKBp"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "prompt_1 = dict(prompt = \"\"\"\n",
        "Hi Indian Bakery Central,\n",
        "Do you happen to have 10 pendas, and thirty bundi ladoos on hand? Also do you sell a vanilla frosting and chocolate flavor cakes. I'm looking for a 6 inch size\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"inquiry\",\n",
        "    \"items\": [\n",
        "      {\n",
        "        \"name\": \"pendas\",\n",
        "        \"quantity\": 10\n",
        "      },\n",
        "      {\n",
        "        \"name\": \"bundi ladoos\",\n",
        "        \"quantity\": 30\n",
        "      },\n",
        "      {\n",
        "        \"name\": \"cake\",\n",
        "        \"filling\": null,\n",
        "        \"frosting\": \"vanilla\",\n",
        "        \"flavor\": \"chocolate\",\n",
        "        \"size\": \"6 in\"\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Em3oeIHFPbs9"
      },
      "outputs": [],
      "source": [
        "{\n",
        "    \"training_prompt\": \"\"\"\n",
        "Hi Indian Bakery Central,\n",
        "Do you happen to have 10 pendas, and thirty bundi ladoos on hand? Also do you sell a vanilla frosting and chocolate flavor cakes. I'm looking for a 6 inch size\n",
        "\"\"\"\n",
        "  \"response\":\"\"\"\n",
        "    [\n",
        "      {\n",
        "        \"name\": \"pendas\",\n",
        "        \"quantity\": 10\n",
        "      },\n",
        "      {\n",
        "        \"name\": \"bundi ladoos\",\n",
        "        \"quantity\": 30\n",
        "      },\n",
        "      {\n",
        "        \"name\": \"cake\",\n",
        "        \"filling\": null,\n",
        "        \"frosting\": \"vanilla\",\n",
        "        \"flavor\": \"chocolate\",\n",
        "        \"size\": \"6 in\"\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\"\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EZLQkBcotKYD"
      },
      "outputs": [],
      "source": [
        "prompt_2 = dict(prompt = \"\"\"\n",
        "I saw your business on google maps. Do you sell jellabi and gulab jamun?\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"inquiry\",\n",
        "    \"items\": [\n",
        "      {\n",
        "        \"name\": \"jellabi\",\n",
        "        \"quantity\": null\n",
        "      },\n",
        "      {\n",
        "        \"name\": \"gulab jamun\",\n",
        "        \"quantity\": null\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LqHk5nHftKj9"
      },
      "outputs": [],
      "source": [
        "prompt_3 = dict(prompt = \"\"\"\n",
        "I'd like to place an order for a 8 inch red velvet cake with lemon frosting and chocolate chips topping.\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"order\",\n",
        "    \"items\": [\n",
        "      {\n",
        "        \"name\": \"cake\",\n",
        "        \"filling\": \"8inch\",\n",
        "        \"frosting\": \"lemon\",\n",
        "        \"flavor\": \"chocolate\",\n",
        "        \"size\": \"8 in\"\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3XB9NCuX43lB"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'prompt': \"\\nI'd like four jellabi and three gulab Jamun.\\n\",\n",
              " 'response': {'type': 'order',\n",
              "  'items': [{'name': 'Jellabi', 'quantity': 4},\n",
              "   {'name': 'Gulab Jamun', 'quantity': 3}]}}"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "prompt_4 = dict(prompt = \"\"\"\n",
        "I'd like four jellabi and three gulab Jamun.\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"order\",\n",
        "    \"items\": [\n",
        "      {\n",
        "        \"name\": \"Jellabi\",\n",
        "        \"quantity\": 4\n",
        "      },\n",
        "      {\n",
        "        \"name\": \"Gulab Jamun\",\n",
        "        \"quantity\": 3\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")\n",
        "prompt_4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h7i2-tXMS7V6"
      },
      "outputs": [],
      "source": [
        "prompt_4_2 = dict(prompt = \"\"\"\n",
        "Please pack me a box with 10 halva.\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"order\",\n",
        "    \"items\": [\n",
        "      {\n",
        "        \"name\": \"halva\",\n",
        "        \"quantity\": 10\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tMetn-wmjUuX"
      },
      "outputs": [],
      "source": [
        "prompt_5 = dict(prompt = \"\"\"\n",
        "Do you sell strawberry cakes with vanilla frosting with custard inside?\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"inquiry\",\n",
        "    \"items\": [\n",
        "      {\n",
        "        \"name\": \"cake\",\n",
        "        \"filling\": \"custard\",\n",
        "        \"frosting\": \"vanilla\",\n",
        "        \"flavor\": \"strawberry\",\n",
        "        \"size\": \"null\"\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w8VoAbFISq-X"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'prompt': '\\nDo you sell strawberry cakes with vanilla frosting with custard inside?\\n',\n",
              " 'response': {'type': 'inquiry',\n",
              "  'items': [{'name': 'cake',\n",
              "    'filling': 'custard',\n",
              "    'frosting': 'vanilla',\n",
              "    'flavor': 'strawberry',\n",
              "    'size': 'null'}]}}"
            ]
          },
          "execution_count": 14,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "prompt_5_2 = dict(prompt = \"\"\"\n",
        "Do you sell carrot cakes with cream cheese frosting?\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"inquiry\",\n",
        "    \"items\": [\n",
        "      {\n",
        "        \"name\": \"cake\",\n",
        "        \"filling\": \"null\",\n",
        "        \"frosting\": \"cream cheese\",\n",
        "        \"flavor\": \"carrot\",\n",
        "        \"size\": \"null\"\n",
        "      }\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")\n",
        "prompt_5"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_-XPKfCL15gx"
      },
      "outputs": [],
      "source": [
        "prompt_6 = dict(prompt = \"\"\"\n",
        "I found your website. What kind of items do you sell?\n",
        "\"\"\",\n",
        "response = json.loads(\"\"\"\n",
        "  {\n",
        "    \"type\": \"inquiry\",\n",
        "    \"items\": [\n",
        "    ]\n",
        "}\n",
        "\"\"\")\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzjZSbDg2DvB"
      },
      "outputs": [],
      "source": [
        "# Starts overfitting on lemon if you add this\n",
        "\n",
        "# prompt_7 = dict(prompt = \"\"\"\n",
        "# Can I buy 18 halva, as well as a lemon cake with lemon frosting?\n",
        "# \"\"\",\n",
        "# response = json.loads(\"\"\"\n",
        "#   {\n",
        "#     \"type\": \"inquiry\",\n",
        "#     \"items\": [\n",
        "#       {\n",
        "#         \"name\": \"halva\",\n",
        "#         \"quantity\": 18\n",
        "#       },\n",
        "#       {\n",
        "#         \"filling\": null,\n",
        "#         \"frosting\": \"lemon\",\n",
        "#         \"flavor\": \"lemon\",\n",
        "#         \"size\": null\n",
        "#       }\n",
        "#     ]\n",
        "# }\n",
        "# \"\"\")\n",
        "# )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FknNEB26yHRN"
      },
      "outputs": [],
      "source": [
        "data = []\n",
        "\n",
        "for prompt in [prompt_1, prompt_2, prompt_3, prompt_4, prompt_4_2, prompt_5, prompt_5_2, prompt_6]:\n",
        "  data.append(template.format(instruction=prompt[\"prompt\"],response=prompt[\"response\"]))"
      ]
    },
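    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each training string is a customer message followed by the expected response. Because the responses are parsed into Python dicts, how they are turned back into text matters: `str.format` renders a dict with Python's repr, not as JSON. A minimal sketch of the difference; the `example` record below is hypothetical and `template` is redefined only so the block is self-contained:\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "template = \"{instruction}\\n{response}\"  # Same shape as the template used earlier.\n",
        "\n",
        "# Hypothetical record, for illustration only.\n",
        "example = {\n",
        "    \"prompt\": \"Do you sell halva?\",\n",
        "    \"response\": {\"type\": \"inquiry\", \"items\": [{\"name\": \"halva\", \"quantity\": None}]},\n",
        "}\n",
        "\n",
        "# Formatting the dict directly uses Python's repr: single quotes and None.\n",
        "as_repr = template.format(instruction=example[\"prompt\"], response=example[\"response\"])\n",
        "\n",
        "# Serializing first yields valid JSON: double quotes and null.\n",
        "as_json = template.format(\n",
        "    instruction=example[\"prompt\"], response=json.dumps(example[\"response\"])\n",
        ")\n",
        "\n",
        "print(as_repr)\n",
        "print(as_json)\n",
        "```\n",
        "\n",
        "If the tuned model is expected to emit parseable JSON, the serialized form is likely the safer training target."
      ]
    },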
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pt7Nr6a7tItO"
      },
      "source": [
        "## LoRA Fine-tuning\n",
        "\n",
        "The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.\n",
        "\n",
        "A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n",
        "\n",
        "This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.\n",
        "\n",
        "Be careful for over or underfit\n",
        "* Rank\n",
        "* Learning Rate\n",
        "*"
      ]
    },
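    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the rank trade-off concrete: for a single weight matrix of shape `(d_in, d_out)`, LoRA trains two low-rank factors of shapes `(d_in, r)` and `(r, d_out)`, which adds `r * (d_in + d_out)` trainable parameters. A back-of-the-envelope sketch; the 2304 x 2304 shape is an illustrative assumption (2304 is the backbone's hidden size in the model summary), not the actual shape of any particular Gemma layer, and `lora_param_count` is a hypothetical helper:\n",
        "\n",
        "```python\n",
        "def lora_param_count(d_in, d_out, rank):\n",
        "    \"\"\"Trainable parameters LoRA adds to one (d_in, d_out) weight matrix.\"\"\"\n",
        "    return rank * (d_in + d_out)\n",
        "\n",
        "\n",
        "d_in, d_out = 2304, 2304  # Illustrative shape, assumed for this example.\n",
        "full = d_in * d_out\n",
        "\n",
        "for rank in (4, 8, 16):\n",
        "    added = lora_param_count(d_in, d_out, rank)\n",
        "    print(f\"rank={rank}: {added:,} LoRA params ({added / full:.2%} of {full:,})\")\n",
        "```\n",
        "\n",
        "The added parameter count grows linearly with the rank, which is why a small rank is a cheap starting point for experiments."
      ]
    },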
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RCucu6oHz53G"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma_causal_lm_preprocessor\"</span>\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Tokenizer (type)                                   </span>┃<span style=\"font-weight: bold\">                                             Vocab # </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ gemma_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaTokenizer</span>)                   │                                             <span style=\"color: #00af00; text-decoration-color: #00af00\">256,000</span> │\n",
              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
              "</pre>\n"
            ],
            "text/plain": [
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type)                                  \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m                                            Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m)                   │                                             \u001b[38;5;34m256,000\u001b[0m │\n",
              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gemma_causal_lm\"</span>\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Layer (type)                  </span>┃<span style=\"font-weight: bold\"> Output Shape              </span>┃<span style=\"font-weight: bold\">         Param # </span>┃<span style=\"font-weight: bold\"> Connected to               </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)     │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ gemma_backbone                │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2304</span>)        │   <span style=\"color: #00af00; text-decoration-color: #00af00\">2,617,270,528</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],        │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaBackbone</span>)               │                           │                 │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]            │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_embedding               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256000</span>)      │     <span style=\"color: #00af00; text-decoration-color: #00af00\">589,824,000</span> │ gemma_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]       │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>)         │                           │                 │                            │\n",
              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
              "</pre>\n"
            ],
            "text/plain": [
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                 \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape             \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m        Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to              \u001b[0m\u001b[1m \u001b[0m┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m)     │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m)        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ gemma_backbone                │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m)        │   \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m],        │\n",
              "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m)               │                           │                 │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]            │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_embedding               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m)      │     \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n",
              "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m)         │                           │                 │                            │\n",
              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,617,270,528</span> (9.75 GB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,928,640</span> (11.17 MB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,614,341,888</span> (9.74 GB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Enable LoRA for the model and set the LoRA rank to 4.\n",
        "gemma_lm.backbone.enable_lora(rank=4)\n",
        "gemma_lm.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQQ47kcdpbZ9"
      },
      "source": [
        "Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million)."
      ]
    },
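    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The trainable-parameter count in the summary above can be reproduced by hand. The following cell is a rough sketch under stated assumptions, not the library's implementation. Only the hidden size of 2304 comes from the summary. The layer count (26), the number of query heads (8), the number of key/value heads (4), and the head size (256) are assumed values for this model. The sketch also assumes that `enable_lora` adapts only the query and value projections, with one factor of shape `(num_heads, hidden_dim, rank)` and one of shape `(rank, head_dim)` per projection. Under these assumptions, a rank of 4 gives 2,928,640 parameters, which matches the summary, and the count scales linearly with the rank."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: estimate LoRA trainable parameters for a given rank.\n",
        "HIDDEN_DIM = 2304     # from the model summary above\n",
        "NUM_LAYERS = 26       # assumption, not stated in this notebook\n",
        "NUM_QUERY_HEADS = 8   # assumption\n",
        "NUM_KV_HEADS = 4      # assumption\n",
        "HEAD_DIM = 256        # assumption\n",
        "\n",
        "\n",
        "def lora_params_per_projection(num_heads, hidden_dim, head_dim, rank):\n",
        "    # Assumed factor shapes: A is (num_heads, hidden_dim, rank), B is (rank, head_dim).\n",
        "    factor_a = num_heads * hidden_dim * rank\n",
        "    factor_b = rank * head_dim\n",
        "    return factor_a + factor_b\n",
        "\n",
        "\n",
        "def lora_trainable_params(rank):\n",
        "    # Assumes only the query and value projections of each layer are adapted.\n",
        "    query = lora_params_per_projection(NUM_QUERY_HEADS, HIDDEN_DIM, HEAD_DIM, rank)\n",
        "    value = lora_params_per_projection(NUM_KV_HEADS, HIDDEN_DIM, HEAD_DIM, rank)\n",
        "    return NUM_LAYERS * (query + value)\n",
        "\n",
        "\n",
        "for rank in (4, 8, 16):\n",
        "    print(f'rank={rank}: {lora_trainable_params(rank):,} trainable parameters')"
      ]
    },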
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Peq7TnLtHse"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/3\n",
            "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m83s\u001b[0m 6s/step - loss: 0.7486 - sparse_categorical_accuracy: 0.6278\n",
            "Epoch 2/3\n",
            "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m42s\u001b[0m 2s/step - loss: 0.5113 - sparse_categorical_accuracy: 0.6984\n",
            "Epoch 3/3\n",
            "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 770ms/step - loss: 0.3469 - sparse_categorical_accuracy: 0.7796\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "<keras.src.callbacks.history.History at 0x7c8dfc58b1c0>"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# for weight_decay in [.009, .0001, ]:\n",
        "  # Generate Examples\n",
        "\n",
        "# Limit the input sequence length to 256 (to control memory usage).\n",
        "gemma_lm.preprocessor.sequence_length = 256\n",
        "# Use AdamW (a common optimizer for transformer models).\n",
        "optimizer = keras.optimizers.AdamW(\n",
        "    learning_rate=9e-4,\n",
        "    weight_decay=0.004,\n",
        ")\n",
        "# Exclude layernorm and bias terms from decay.\n",
        "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
        "\n",
        "gemma_lm.compile(\n",
        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer=optimizer,\n",
        "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
        ")\n",
        "gemma_lm.fit(data, epochs=3, batch_size=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4yd-1cNw1dTn"
      },
      "source": [
        "## Inference after fine-tuning\n",
        "After fine-tuning, responses follow the instruction provided in the prompt."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H55JYJ1a1Kos"
      },
      "source": [
        "### Order Prompt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y7cDJHy8WfCB"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Hi, I'd like to order an 8 inch red velvet cake with custard filling\n",
            "{'type': 'order', 'items': [{'name': 'cake', 'filling': 'custard', 'size': '8 inch', 'flavor': 'red velvet'}]}\n"
          ]
        }
      ],
      "source": [
        "prompt = template.format(\n",
        "    instruction=\"\"\"Hi, I'd like to order an 8 inch red velvet cake with custard filling\"\"\",\n",
        "    response=\"\",\n",
        ")\n",
        "\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X-2sYl2jqwl7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Hi Indian Bakery Central,\n",
            "                   I'd like to order one lemon cake that has vanilla filing, 2 gulab jamun and 1 penda\n",
            "{'type': 'order', 'items': [{'name': 'lemon cake', 'filling': 'vanilla', 'quantity': '1'}, {'name': 'gulab jamun', 'quantity': '2'}, {'name': 'penda', 'quantity': '1'}]}\n"
          ]
        }
      ],
      "source": [
        "# Misspelling\n",
        "prompt = template.format(\n",
        "    instruction=\"\"\"Hi Indian Bakery Central,\n",
        "                   I'd like to order one lemon cake that has vanilla filing, 2 gulab jamun and 1 penda\"\"\",\n",
        "    response=\"\",\n",
        ")\n",
        "\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2i5ll-QK2Hym"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Hello, do you have 20 pendas and 10 ladoos? Also Can you make a chocolate cake with raspberry filling?\n",
            "{'type': 'inquiry', 'items': [{'name': 'pendas', 'quantity': 20}, {'name': 'ladoos', 'quantity': 10}], 'order': {'cake': 'chocolate', 'filling': 'raspberry', 'size': 'null'}}\n"
          ]
        }
      ],
      "source": [
        "# Failure case\n",
        "prompt = template.format(\n",
        "    instruction=\"\"\"Hello, do you have 20 pendas and 10 ladoos? Also Can you make a chocolate cake with raspberry filling?\"\"\",\n",
        "    response=\"\",\n",
        ")\n",
        "\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
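    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The generated text contains the prompt followed by a Python-style dict literal (its single-quoted strings mean it is not valid JSON). An application that consumes this output would likely need to parse and sanity-check it. The following cell is one possible sketch, not part of the original workflow: `parse_response`, `unexpected_keys`, and `EXPECTED_KEYS` are hypothetical names, and the expected top-level keys are inferred from the successful outputs above rather than from a documented schema. It runs on two outputs copied from the cells above, so it does not need the model.\n",
        "\n",
        "On these two samples, the check passes for the order prompt and flags the extra `order` key in the failure case. Malformed output would make `ast.literal_eval` raise `ValueError` or `SyntaxError`, which a caller should catch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import ast\n",
        "\n",
        "# Assumption: top-level keys inferred from the successful outputs above.\n",
        "EXPECTED_KEYS = {'type', 'items'}\n",
        "\n",
        "\n",
        "def parse_response(generated):\n",
        "    # Assumes the response spans the first '{' through the last '}' of the text.\n",
        "    start = generated.find('{')\n",
        "    end = generated.rfind('}')\n",
        "    if start == -1 or end < start:\n",
        "        raise ValueError('no dict literal found in the generated text')\n",
        "    parsed = ast.literal_eval(generated[start:end + 1])\n",
        "    if not isinstance(parsed, dict):\n",
        "        raise ValueError('the generated response is not a dict')\n",
        "    return parsed\n",
        "\n",
        "\n",
        "def unexpected_keys(parsed):\n",
        "    # Top-level keys that fall outside the assumed schema.\n",
        "    return set(parsed) - EXPECTED_KEYS\n",
        "\n",
        "\n",
        "# Outputs copied from the cells above.\n",
        "order_output = '''Hi, I'd like to order an 8 inch red velvet cake with custard filling\n",
        "{'type': 'order', 'items': [{'name': 'cake', 'filling': 'custard', 'size': '8 inch', 'flavor': 'red velvet'}]}'''\n",
        "\n",
        "failure_output = '''Hello, do you have 20 pendas and 10 ladoos? Also Can you make a chocolate cake with raspberry filling?\n",
        "{'type': 'inquiry', 'items': [{'name': 'pendas', 'quantity': 20}, {'name': 'ladoos', 'quantity': 10}], 'order': {'cake': 'chocolate', 'filling': 'raspberry', 'size': 'null'}}'''\n",
        "\n",
        "for text in (order_output, failure_output):\n",
        "    parsed = parse_response(text)\n",
        "    print(parsed['type'], sorted(unexpected_keys(parsed)))"
      ]
    },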
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gSsRdeiof_rJ"
      },
      "source": [
        "## Summary and next steps\n",
        "\n",
        "This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:\n",
        "\n",
        "* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
        "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).\n",
        "* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
        "* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "bakery_inquiry_model_tuned_with_gemma.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
